import pinecone
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from pinecone import Pinecone, ServerlessSpec

# 初始化 Pinecone 客户端
api_key = "c3e93ced-9abd-451a-afcd-73b1e5a9703e"
pc = Pinecone(api_key=api_key)

# 创建或连接到索引
index_name = "mnist-index"
if index_name not in pc.list_indexes().names():
    pc.create_index(
        name=index_name,
        dimension=64,  # 对于手写数字数据集，每个样本有64个特征
        metric='euclidean',  # 使用欧几里得距离进行相似度计算
        spec=ServerlessSpec(cloud='aws', region='us-east-1')  # 使用正确的区域和云提供商
    )

index = pc.Index(index_name)

# 创建一个手写数字 3 的图像
digit_3 = np.array(
    [[0, 0, 0, 255, 0, 0, 0, 0],
     [0, 0, 0, 255, 0, 0, 0, 0],
     [0, 0, 0, 255, 0, 0, 0, 0],
     [0, 0, 0, 255, 0, 0, 0, 0],
     [0, 0, 0, 255, 0, 0, 0, 0],
     [0, 0, 0, 255, 0, 0, 0, 0],
     [0, 0, 0, 255, 0, 0, 0, 0],
     [0, 0, 0, 255, 0, 0, 0, 0]]
)

# 将图像像素值从 0-255 的范围缩放到 0-16 的范围
digit_3_flatten = (digit_3 / 255.0) * 16
query_data = digit_3_flatten.ravel().tolist()

# 执行 Pinecone 查询
results = index.query(
    vector=query_data,
    top_k=11,
    include_metadata=True
)

# 从搜索结果中提取每个匹配项的标签
labels = [match['metadata']['label'] for match in results['matches']]

# 打印每个匹配结果的详细信息
for match, label in zip(results['matches'], labels):
    print(f"id: {match['id']}, distance: {match['score']}, label: {label}")

# 使用投票机制确定最终的分类结果
final_prediction = Counter(labels).most_common(1)[0][0]

# 使用 matplotlib 显示查询图像和预测结果
plt.imshow(digit_3, cmap='gray')
plt.title(f"Predicted digit: {final_prediction}", size=15)
plt.axis('off')
plt.show()